# Copyright (c) 2024 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from msit_llm.common.log import logger

import torch
import torch.nn.functional as function
import torch.distributed as dist

# Global registry mapping a module object (torch, torch.nn.functional,
# torch.distributed, and optionally torch_npu) to the list of operator
# names within it that should be hooked. Populated by add_torch_ops()
# and add_torch_npu_ops() below.
HOOK_OPS = {}


def add_torch_ops():
    """Register the default operator names for hooking.

    Fills the global ``HOOK_OPS`` registry with three entries, keyed by
    module object: ``torch.nn.functional`` (imported here as ``function``),
    ``torch``, and ``torch.distributed`` (``dist``). Each value is the list
    of attribute names on that module to be hooked.
    """
    # torch.nn.functional: activation, normalization and shaping ops.
    functional_names = (
        "threshold threshold_ relu relu_ glu hardtanh hardtanh_ relu6 "
        "elu elu_ selu selu_ celu celu_ leaky_relu leaky_relu_ prelu "
        "rrelu rrelu_ logsigmoid gelu hardshrink tanhshrink softsign "
        "softplus softmin softmax gumbel_softmax log_softmax softshrink "
        "tanh sigmoid hardsigmoid silu hardswish pixel_shuffle "
        "pixel_unshuffle channel_shuffle upsample_nearest "
        "upsample_bilinear grid_sample affine_grid pdist one_hot"
    ).split()

    # torch: elementwise math followed by BLAS / linear-algebra routines.
    torch_names = (
        "abs absolute acos arccos acosh arccosh add addcdiv addcmul angle "
        "asin arcsin asinh arcsinh atan arctan atanh arctanh atan2 arctan2 "
        "bitwise_not bitwise_and bitwise_or bitwise_xor bitwise_left_shift "
        "bitwise_right_shift ceil clamp clip conj_physical copysign cos "
        "cosh deg2rad div divide digamma erf erfc erfinv exp exp2 expm1 "
        "fake_quantize_per_channel_affine fake_quantize_per_tensor_affine "
        "fix float_power floor floor_divide fmod frac frexp gradient imag "
        "ldexp lerp lgamma log log10 log1p log2 logaddexp logaddexp2 "
        "logical_and logical_not logical_or logical_xor logit hypot i0 "
        "igamma igammac mul multiply mvlgamma nan_to_num neg negative "
        "nextafter polygamma positive pow quantized_batch_norm "
        "quantized_max_pool1d quantized_max_pool2d rad2deg real "
        "reciprocal remainder round rsqrt sigmoid sign sgn signbit sin "
        "sinc sinh softmax sqrt square sub subtract tan tanh true_divide "
        "trunc xlogy addbmm addmm addmv addr baddbmm bmm chain_matmul "
        "cholesky cholesky_inverse cholesky_solve dot geqrf ger inner "
        "inverse det logdet slogdet lu lu_solve lu_unpack matmul "
        "matrix_power matrix_exp mm mv orgqr ormqr outer pinverse qr svd "
        "svd_lowrank pca_lowrank lobpcg trapz trapezoid "
        "cumulative_trapezoid triangular_solve vdot"
    ).split()

    # torch.distributed: collective and point-to-point communication ops.
    dist_names = (
        "send recv broadcast all_reduce reduce all_gather gather isend "
        "irecv scatter reduce_scatter"
    ).split()

    # Insertion order (function, torch, dist) matches the original literal.
    HOOK_OPS.update({
        function: functional_names,
        torch: torch_names,
        dist: dist_names,
    })


def add_torch_npu_ops():
    """Register torch_npu operator names for hooking.

    Adds a ``torch_npu`` entry to the global ``HOOK_OPS`` registry.
    If ``torch_npu`` is not importable, logs a warning and leaves the
    registry untouched.
    """
    try:
        import torch_npu
    except ImportError:
        # NPU extension is optional; skip registration when absent.
        logger.warning("torch_npu is not installed.")
        return

    HOOK_OPS[torch_npu] = (
        "fast_gelu npu_mish npu_scaled_masked_softmax "
        "npu_dropout_with_add_softmax npu_random_choice_with_mask "
        "npu_roi_align npu_roi_alignbk npu_all_gather_base_mm"
    ).split()
